package com.stars.easyms.logging.interceptor;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.stars.easyms.base.trace.EasyMsTraceHelper;
import com.stars.easyms.base.trace.EasyMsTraceSynchronizationManager;
import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.MDC;
import org.springframework.aop.support.AopUtils;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.annotation.AnnotationAsyncExecutionInterceptor;
import org.springframework.util.ClassUtils;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

/**
 * <p>className: EasyMsLoggingAnnotationAsyncExecutionInterceptor</p>
 * <p>description: Custom AnnotationAsyncExecutionInterceptor for the EasyMs logging module. It carries the
 * caller's MDC map and TransmittableThreadLocal values into the asynchronous task and registers the task with
 * EasyMsTraceSynchronizationManager.</p>
 *
 * @author guoguifang
 * @version 1.7.1
 * @date 2020/12/19 10:55 AM
 */
public class EasyMsLoggingAnnotationAsyncExecutionInterceptor extends AnnotationAsyncExecutionInterceptor {

    /**
     * Creates the interceptor without an explicit default executor; executor resolution is left to the
     * Spring superclass.
     */
    public EasyMsLoggingAnnotationAsyncExecutionInterceptor() {
        super(null);
    }

    /**
     * Intercepts an asynchronous method call and submits it to the {@link AsyncTaskExecutor} determined for
     * the method.
     * <p>The caller thread's MDC map and TransmittableThreadLocal values are captured before submission. The
     * task applies them on the executing thread and, when it finishes, puts back that thread's own previous MDC
     * map and TransmittableThreadLocal values. Restoring (instead of clearing) keeps the caller's context intact
     * when the executor happens to run the task on the calling thread.</p>
     *
     * @param invocation the intercepted method invocation
     * @return the value produced by {@code doSubmit} for the method's declared return type
     * @throws IllegalStateException if no executor can be determined for the method
     * @throws Throwable as permitted by the {@code MethodInterceptor} contract
     */
    @Override
    @Nullable
    public Object invoke(final MethodInvocation invocation) throws Throwable {
        Class<?> targetClass = (invocation.getThis() != null ? AopUtils.getTargetClass(invocation.getThis()) : null);
        Method specificMethod = ClassUtils.getMostSpecificMethod(invocation.getMethod(), targetClass);
        final Method userDeclaredMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);

        AsyncTaskExecutor executor = determineAsyncExecutor(userDeclaredMethod);
        if (executor == null) {
            throw new IllegalStateException(
                    "No executor specified and no default executor set on AsyncExecutionInterceptor either");
        }

        // Snapshot the caller's context now; the task runs later, usually on another thread.
        // getCopyOfContextMap() may return null when the caller has no MDC values.
        final Map<String, String> callerContextMap = MDC.getCopyOfContextMap();
        final Object captured = TransmittableThreadLocal.Transmitter.capture();
        Callable<Object> task = () -> {
            // Remember the executing thread's own context so it can be put back afterwards
            // (this matters when the task runs on the calling thread).
            final Map<String, String> previousContextMap = MDC.getCopyOfContextMap();
            final Object backup = TransmittableThreadLocal.Transmitter.replay(captured);
            try {
                applyMdcContextMap(callerContextMap);
                return EasyMsTraceHelper.setTraceIdIfNecessary(t -> {
                    EasyMsTraceSynchronizationManager.entryAsyncThread(t);
                    Object result = invocation.proceed();
                    if (result instanceof Future) {
                        return ((Future<?>) result).get();
                    }
                    return null;
                });
            } catch (ExecutionException ex) {
                // Unwrap the failure of the method's returned Future.
                handleError(ex.getCause(), userDeclaredMethod, invocation.getArguments());
            } catch (Throwable ex) {
                handleError(ex, userDeclaredMethod, invocation.getArguments());
            } finally {
                EasyMsTraceSynchronizationManager.clearAsyncId();
                applyMdcContextMap(previousContextMap);
                TransmittableThreadLocal.Transmitter.restore(backup);
            }
            return null;
        };

        return doSubmit(task, executor, invocation.getMethod().getReturnType());
    }

    /**
     * Installs {@code contextMap} as the current thread's MDC map, or clears MDC when it is {@code null}.
     * <p>Guarding against {@code null} is required because some MDC adapters (for example Logback 1.2.x)
     * throw a NullPointerException from {@code MDC.setContextMap(null)}.</p>
     *
     * @param contextMap the map to install; {@code null} means "no MDC values"
     */
    private static void applyMdcContextMap(@Nullable Map<String, String> contextMap) {
        if (contextMap == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(contextMap);
        }
    }

}
